#!/usr/bin/env python3
"""
Phase 3: Demographics Mediating Twin Deficits
Research question F: Does the fiscal–CA link depend on demographics?

Models:
  1. Baseline twin deficits: ca_gdp = β·fiscal + controls
  2. + Z controls: does β change?
  3. fiscal × Z interactions (fiscal × Z₁, fiscal × Z₂, fiscal × Z₃)
  4. fiscal × old_dep simpler interaction
  5. fiscal × old_dep × kaopen triple
  6. Subsample by demographic tercile
  7. Pension system moderator (OECD subsample)
  8. Common cause test: fiscal residualized on Z

Output: extensions/output/tables/twin_deficits.md (formatted tables)
        extensions/output/tables/twin_deficits_results.csv (all coefficients)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

# ── Paths ──────────────────────────────────────────────────────────────────
# PROJECT_DIR is the parent of the directory that contains this script;
# ROOT_DIR is one level above PROJECT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"

# The tables directory is created at import time (including any missing
# parents) so that later writes into it do not fail on a missing folder.
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Prepend multilateral/src to sys.path so that the `model` module is looked
# up there first when PanelGLS is imported.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS


def load_panel():
    """Load the processed extensions panel, restricted to years up to 2024.

    Returns:
        A copy of the rows of ``data/processed/extensions_panel.csv``
        (under ``PROJECT_DIR``) whose ``year`` value is at most 2024.
    """
    csv_file = PROJECT_DIR / "data" / "processed" / "extensions_panel.csv"
    panel = pd.read_csv(csv_file)
    within_sample = panel["year"] <= 2024
    return panel.loc[within_sample].copy()


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``"***"`` for p < 0.01, ``"**"`` for p < 0.05, ``"*"`` for p < 0.1 and
    an empty string otherwise (including when no comparison holds, e.g. NaN).
    """
    # Cutoffs are checked from strictest to loosest; the first match wins.
    thresholds = ((0.01, "***"), (0.05, "**"), (0.1, "*"))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ""


def run_model(df, dep_var, regressors, label):
    """Fit a PanelGLS of ``dep_var`` on ``regressors`` and tabulate the fit.

    Rows with a missing value in the dependent variable or in any regressor
    are dropped before estimation.

    Args:
        df: Panel data containing ``dep_var``, the regressors, ``iso3`` and
            ``year`` columns.
        dep_var: Name of the dependent-variable column.
        regressors: List of regressor column names.
        label: Model label used in console output and in the ``model`` column.

    Returns:
        The model's coefficient table (from ``to_dataframe``) with ``model``,
        ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared`` and ``rho``
        columns added, or ``None`` when fewer than 100 complete rows remain.
    """
    required = [dep_var] + regressors
    sample = df.dropna(subset=required).copy()
    n_rows = len(sample)
    if n_rows < 100:
        print(f"  SKIP {label}: only {n_rows} obs")
        return None

    gls = PanelGLS()
    y = sample[dep_var].values
    X = sample[regressors].values
    gls.fit(y, X, sample["iso3"].values, sample["year"].values)

    print(f"\n  {label}")
    gls.summary(feature_names=regressors)

    table = gls.to_dataframe(feature_names=regressors)
    # Broadcast model-level metadata onto every coefficient row.
    table["model"] = label
    table["dep_var"] = dep_var
    for attr in ("n_obs", "n_countries", "r_squared", "rho"):
        table[attr] = getattr(gls, attr)
    return table


def main():
    """Run the Phase 3 twin-deficit specifications and save their results.

    Loads the panel, estimates models M1–M8 (each one is skipped by
    ``run_model`` when its complete-case sample is too small), prints
    progress to stdout, and, if at least one model was estimated, writes the
    stacked coefficient table to ``twin_deficits_results.csv`` and the
    formatted summary via ``write_markdown_table``.
    """
    print("=" * 70)
    print("PHASE 3: Demographics Mediating Twin Deficits")
    print("=" * 70)

    df = load_panel()

    demo_vars = ["Z_1", "Z_2", "Z_3"]
    non_fiscal_controls = ["nfa_gdp_lag", "log_rel_opw", "kaopen"]
    # Keep only controls that exist and have more than 200 non-missing values.
    non_fiscal_controls = [c for c in non_fiscal_controls
                           if c in df.columns and df[c].notna().sum() > 200]

    all_results = []

    # ── Model 1: Baseline twin deficits (no Z) ────────────────────────────
    # Stays None when M1 is skipped, so M2 can tell that no baseline exists.
    fiscal_coef_no_z = None
    m1_vars = ["fiscal_bal_gdp"] + non_fiscal_controls
    r = run_model(df, "ca_gdp", m1_vars, "M1: Twin deficits (no Z)")
    if r is not None:
        all_results.append(r)
        fiscal_coef_no_z = r.loc[r["variable"] == "fiscal_bal_gdp",
                                  "coefficient"].values[0]
        print(f"  → fiscal coef without Z: {fiscal_coef_no_z:.4f}")

    # ── Model 2: Twin deficits + Z controls ────────────────────────────────
    m2_vars = ["fiscal_bal_gdp"] + non_fiscal_controls + demo_vars
    r = run_model(df, "ca_gdp", m2_vars, "M2: Twin deficits + Z")
    if r is not None:
        all_results.append(r)
        fiscal_coef_with_z = r.loc[r["variable"] == "fiscal_bal_gdp",
                                    "coefficient"].values[0]
        print(f"  → fiscal coef with Z: {fiscal_coef_with_z:.4f}")
        # The percentage change needs a non-zero M1 baseline; without one,
        # report that instead of failing on an unbound name or a division
        # by zero.
        if fiscal_coef_no_z is None:
            print("  → Change: n/a (M1 baseline was skipped)")
        elif fiscal_coef_no_z == 0:
            print("  → Change: n/a (M1 fiscal coef is zero)")
        else:
            change = (fiscal_coef_with_z - fiscal_coef_no_z) / fiscal_coef_no_z * 100
            print(f"  → Change: {change:.1f}%")

    # ── Model 3: fiscal × Z interactions ──────────────────────────────────
    # Interaction columns absent from the panel are dropped from the
    # specification; if none exist this reduces to the M2 regressor set.
    m3_vars = (["fiscal_bal_gdp"] + non_fiscal_controls + demo_vars
               + ["fiscal_x_Z_1", "fiscal_x_Z_2", "fiscal_x_Z_3"])
    m3_vars = [v for v in m3_vars if v in df.columns]
    r = run_model(df, "ca_gdp", m3_vars,
                  "M3: fiscal × Z interaction")
    if r is not None:
        all_results.append(r)

    # ── Model 4: fiscal × old_dep (simpler) ────────────────────────────────
    m4_vars = (["fiscal_bal_gdp", "old_dep", "fiscal_x_old_dep"]
               + non_fiscal_controls)
    m4_vars = [v for v in m4_vars if v in df.columns]
    r = run_model(df, "ca_gdp", m4_vars,
                  "M4: fiscal × old_dep")
    if r is not None:
        all_results.append(r)

    # ── Model 5: fiscal × old_dep × kaopen triple ─────────────────────────
    # The triple interaction is built here, so all three inputs must exist;
    # skip the model rather than raise KeyError when one is missing.
    triple_inputs = ["fiscal_bal_gdp", "old_dep", "kaopen"]
    missing_inputs = [c for c in triple_inputs if c not in df.columns]
    if missing_inputs:
        print(f"  SKIP M5: fiscal × old_dep × kaopen: "
              f"missing columns {missing_inputs}")
    else:
        df["fiscal_x_od_x_kaopen"] = (df["fiscal_bal_gdp"] * df["old_dep"]
                                       * df["kaopen"])
        m5_vars = (["fiscal_bal_gdp", "old_dep", "kaopen",
                    "fiscal_x_old_dep", "fiscal_x_od_x_kaopen"]
                   + [c for c in non_fiscal_controls if c != "kaopen"])
        m5_vars = [v for v in m5_vars if v in df.columns]
        r = run_model(df, "ca_gdp", m5_vars,
                      "M5: fiscal × old_dep × kaopen")
        if r is not None:
            all_results.append(r)

    # ── Model 6: Subsample by demographic tercile ─────────────────────────
    for tercile, label in [(1, "early"), (2, "mid"), (3, "late")]:
        sub = df[df["demo_tercile"] == tercile].copy()
        m6_vars = ["fiscal_bal_gdp"] + non_fiscal_controls + demo_vars
        # chr(96 + tercile) maps 1, 2, 3 to the suffixes a, b, c.
        r = run_model(sub, "ca_gdp", m6_vars,
                      f"M6{chr(96+tercile)}: Tercile {tercile} ({label})")
        if r is not None:
            all_results.append(r)

    # ── Model 7: Pension moderator (OECD subsample) ────────────────────────
    # No explicit country filter is applied; the sample is whatever rows have
    # non-missing pension_spending_gdp after run_model's dropna.
    if "pension_spending_gdp" in df.columns:
        df["fiscal_x_pension"] = df["fiscal_bal_gdp"] * df["pension_spending_gdp"]
        m7_vars = (["fiscal_bal_gdp", "pension_spending_gdp",
                    "fiscal_x_pension"]
                   + non_fiscal_controls + demo_vars)
        m7_vars = [v for v in m7_vars if v in df.columns]
        r = run_model(df, "ca_gdp", m7_vars,
                      "M7: Pension moderator")
        if r is not None:
            all_results.append(r)

    # ── Model 8: Common cause test ─────────────────────────────────────────
    # Regress fiscal on Z (plus the non-kaopen controls), take the residual,
    # then re-run the twin-deficit regression with that residual.
    print("\n  --- Common cause test ---")
    fiscal_z_vars = demo_vars + [c for c in non_fiscal_controls
                                  if c != "kaopen"]
    sub_cc = df.dropna(subset=["fiscal_bal_gdp", "ca_gdp"] + fiscal_z_vars
                       + ["kaopen"]).copy()
    if len(sub_cc) >= 200:
        # Stage 1: fiscal = f(Z)
        s1 = PanelGLS()
        s1.fit(sub_cc["fiscal_bal_gdp"].values,
               sub_cc[fiscal_z_vars].values,
               sub_cc["iso3"].values, sub_cc["year"].values)
        print(f"\n  S1: Z → fiscal_bal_gdp")
        s1.summary(feature_names=fiscal_z_vars)

        # NOTE(review): assumes s1.resid holds one value per row of sub_cc,
        # in the same row order — confirm against PanelGLS.
        sub_cc["fiscal_resid"] = s1.resid

        # Stage 2: CA = f(fiscal_resid, controls)
        m8_vars = ["fiscal_resid"] + non_fiscal_controls + demo_vars
        r = run_model(sub_cc, "ca_gdp", m8_vars,
                      "M8: CA ~ fiscal_resid + Z")
        if r is not None:
            all_results.append(r)
            fiscal_resid_coef = r.loc[r["variable"] == "fiscal_resid",
                                       "coefficient"].values[0]
            print(f"  → fiscal_resid coef: {fiscal_resid_coef:.4f}")
            print(f"  → If close to zero: demographics are common cause")

    # ── Combine and save ───────────────────────────────────────────────────
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(OUT_TABLES / "twin_deficits_results.csv",
                          index=False)
        write_markdown_table(results_df)

    print("\n  Done. Results saved to extensions/output/tables/")


def write_markdown_table(results_df):
    """Write the formatted markdown results tables to ``twin_deficits.md``.

    Three sections are produced: a per-model comparison (N, countries, R², ρ),
    the fiscal coefficient in each model, and the interaction-term
    coefficients.

    Args:
        results_df: Stacked coefficient table with ``model``, ``variable``,
            ``coefficient``, ``std_error``, ``p_value``, ``n_obs``,
            ``n_countries``, ``r_squared`` and ``rho`` columns.
    """
    lines = ["# Twin Deficits & Demographics", ""]

    models = results_df["model"].unique()
    lines.append("## Model Comparison")
    lines.append("")
    lines.append("| Model | N | Countries | R² | ρ |")
    lines.append("|-------|---|-----------|----|----|")
    for m in models:
        # Model-level statistics are repeated on every row, so the first
        # row of each model is enough.
        sub = results_df[results_df["model"] == m].iloc[0]
        lines.append(f"| {m} | {sub['n_obs']:,} | "
                     f"{sub['n_countries']} | {sub['r_squared']:.3f} | "
                     f"{sub['rho']:.3f} |")

    # Fiscal coefficient across models
    lines.append("")
    lines.append("## Fiscal Coefficient Across Specifications")
    lines.append("")
    lines.append("| Model | fiscal coef | SE | p-value | Interpretation |")
    lines.append("|-------|-------------|-----|---------|----------------|")
    for m in models:
        sub = results_df[results_df["model"] == m]
        # A model carries either the raw fiscal balance or its residual.
        for fvar in ["fiscal_bal_gdp", "fiscal_resid"]:
            row = sub[sub["variable"] == fvar]
            if len(row) == 0:
                continue
            row = row.iloc[0]
            sig = stars(row["p_value"]) if pd.notna(row["p_value"]) else ""
            lines.append(f"| {m} | {row['coefficient']:.4f}{sig} | "
                         f"{row['std_error']:.4f} | {row['p_value']:.4f} | |")

    # Interaction terms
    lines.append("")
    lines.append("## Interaction Terms")
    lines.append("")
    int_vars = ["fiscal_x_Z_1", "fiscal_x_Z_2", "fiscal_x_Z_3",
                "fiscal_x_old_dep", "fiscal_x_od_x_kaopen",
                "fiscal_x_pension"]
    lines.append("| Variable | Model | Coef | SE | p-value |")
    lines.append("|----------|-------|------|----|---------|")
    for v in int_vars:
        sub = results_df[results_df["variable"] == v]
        for _, row in sub.iterrows():
            # Rows without a p-value are left out of the interaction table.
            if pd.isna(row["p_value"]):
                continue
            sig = stars(row["p_value"])
            lines.append(f"| {v} | {row['model']} | "
                         f"{row['coefficient']:.4f}{sig} | "
                         f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    md_path = OUT_TABLES / "twin_deficits.md"
    # The table contains non-ASCII characters (R², ρ, × in model labels);
    # write UTF-8 explicitly instead of relying on the locale's default
    # encoding, which may be unable to encode them.
    md_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"\n  Markdown table: {md_path}")


# Run the full analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
